package com.jstarcraft.ai.jsat.linear;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * ConcatenatedVec provides a light wrapper around a list of vectors to provide
 * a view of one single vector whose length is the sum of the lengths of the
 * inputs.
 * 
 * @author Edward Raff
 */
public class ConcatenatedVec extends Vec {

    private static final long serialVersionUID = -1412322616974470550L;
    /** The backing vectors, in concatenation order. */
    private Vec[] vecs;
    /** {@code lengthSums[i]} is the index in this view at which {@code vecs[i]} starts. */
    private int[] lengthSums;
    /** Sum of the lengths of all backing vectors. */
    private int totalLength;

    /**
     * Creates a new Vector that is the concatenation of the given vectors in the
     * given order. The vector created is backed by the ones provided, and any
     * mutation to one is visible in the others.
     * 
     * @param vecs the list of vectors to concatenate
     */
    public ConcatenatedVec(List<Vec> vecs) {
        this.vecs = new Vec[vecs.size()];
        lengthSums = new int[vecs.size()];
        totalLength = 0;
        for (int i = 0; i < vecs.size(); i++) {
            // record the start offset of this vector before accounting for its length
            lengthSums[i] = totalLength;
            this.vecs[i] = vecs.get(i);
            totalLength += vecs.get(i).length();
        }
    }

    /**
     * Creates a new Vector that is the concatenation of the given vectors in the
     * given order. The vector created is backed by the ones provided, and any
     * mutation to one is visible in the others.
     * 
     * @param vecs the array of vectors to concatenate
     */
    public ConcatenatedVec(Vec... vecs) {
        this(Arrays.asList(vecs));
    }

    /**
     * @return the combined length of all backing vectors
     */
    @Override
    public int length() {
        return totalLength;
    }

    /**
     * Reads the value at the given index of the concatenated view by delegating
     * to the backing vector that covers that index.
     * 
     * @param index the index in the concatenated view
     * @return the value stored at that index
     */
    @Override
    public double get(int index) {
        int baseIndex = getBaseIndex(index);
        return vecs[baseIndex].get(index - lengthSums[baseIndex]);
    }

    /**
     * Writes the value at the given index of the concatenated view by delegating
     * to the backing vector that covers that index.
     * 
     * @param index the index in the concatenated view
     * @param val   the value to store
     */
    @Override
    public void set(int index, double val) {
        int baseIndex = getBaseIndex(index);
        vecs[baseIndex].set(index - lengthSums[baseIndex], val);
    }

    // The following are implemented only for performance reasons

    /**
     * Increments the value at the given index of the concatenated view by
     * delegating to the backing vector that covers that index.
     * 
     * @param index the index in the concatenated view
     * @param val   the amount to add
     */
    @Override
    public void increment(int index, double val) {
        int baseIndex = getBaseIndex(index);
        vecs[baseIndex].increment(index - lengthSums[baseIndex], val);
    }

    /**
     * @return the sum of the non-zero counts reported by the backing vectors
     */
    @Override
    public int nnz() {
        int nnz = 0;
        for (Vec v : vecs) {
            nnz += v.nnz();
        }
        return nnz;
    }

    /**
     * Performs {@code this += c * b} piecewise: each backing vector receives the
     * slice of {@code b} that lines up with its position in the concatenation.
     * 
     * @param c the multiplier applied to {@code b}
     * @param b the vector to add; its slices are taken at this view's offsets
     */
    @Override
    public void mutableAdd(double c, Vec b) {
        for (int i = 0; i < vecs.length; i++) {
            int len = vecs[i].length();
            if (len == 0) {
                // nothing to add to an empty piece; also avoids building an empty
                // view of b positioned at (possibly) its very end
                continue;
            }
            vecs[i].mutableAdd(c, new SubVector(lengthSums[i], len, b));
        }
    }

    /**
     * Returns an iterator over the non-zero values of the concatenated view,
     * beginning at the given index. The iterator is initialized lazily on the
     * first call to {@code hasNext()} or {@code next()}. The same
     * {@link IndexValue} instance is returned by every call to {@code next()},
     * with its index and value overwritten each time.
     * 
     * @param start the first index of the concatenated view to consider
     * @return an iterator over the non-zero values at or after {@code start}
     */
    @Override
    public Iterator<IndexValue> getNonZeroIterator(final int start) {
        return new Iterator<IndexValue>() {
            // explicit flag instead of a sentinel index, since the index lookup
            // itself may legitimately produce -1
            boolean initialized = false;
            int baseIndex = 0;
            final IndexValue valToSend = new IndexValue(0, 0);
            Iterator<IndexValue> curIter = null;
            IndexValue nextValue = null;

            /**
             * Positions the iterator on the first non-zero value at or after
             * {@code start}; leaves {@code nextValue} null if there is none.
             */
            private void initialize() {
                if (initialized) {
                    return;
                }
                if (vecs.length > 0) {
                    baseIndex = getBaseIndex(start);
                    // only the first vector is entered part way through
                    int localStart = start - lengthSums[baseIndex];
                    while (baseIndex < vecs.length) {
                        Iterator<IndexValue> candidate = vecs[baseIndex].getNonZeroIterator(localStart);
                        if (candidate.hasNext()) {
                            curIter = candidate;
                            nextValue = candidate.next();
                            break;
                        }
                        // this vector is all zeros from here on, move to the next one
                        baseIndex++;
                        localStart = 0;
                    }
                }
                initialized = true;
            }

            @Override
            public boolean hasNext() {
                initialize();
                return nextValue != null;
            }

            @Override
            public IndexValue next() {
                // next() must work even if hasNext() was never called
                initialize();
                if (nextValue == null) {
                    throw new NoSuchElementException();
                }
                // copy out before advancing: translate the local index into the
                // index of the concatenated view
                valToSend.setIndex(nextValue.getIndex() + lengthSums[baseIndex]);
                valToSend.setValue(nextValue.getValue());

                if (curIter.hasNext()) {
                    nextValue = curIter.next();
                } else {
                    baseIndex++;
                    // keep moving till we find a non empty vec
                    while (baseIndex < vecs.length && !(curIter = vecs[baseIndex].getNonZeroIterator()).hasNext()) {
                        baseIndex++;
                    }
                    if (baseIndex >= vecs.length) {
                        // we have run out
                        nextValue = null;
                        curIter = null;
                    } else {
                        nextValue = curIter.next();
                    }
                }

                return valToSend;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Not supported yet.");
            }
        };
    }

    /**
     * @return {@code true} if at least one backing vector reports itself as sparse
     */
    @Override
    public boolean isSparse() {
        for (Vec v : vecs) {
            if (v.isSparse()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a new concatenation over clones of each backing vector, so the
     * result does not share the backing vectors of this object.
     * 
     * @return a new ConcatenatedVec built from cloned backing vectors
     */
    @Override
    public ConcatenatedVec clone() {
        Vec[] newVecs = new Vec[vecs.length];
        for (int i = 0; i < vecs.length; i++) {
            newVecs[i] = vecs[i].clone();
        }
        return new ConcatenatedVec(Arrays.asList(newVecs));
    }

    /**
     * Finds which backing vector covers the given index of the concatenated view.
     * 
     * @param index the index in the concatenated view
     * @return the position in {@code vecs} of the vector whose range starts at or
     *         before {@code index}; -1 if {@code index} precedes every start
     *         offset (negative index, or no backing vectors)
     */
    private int getBaseIndex(int index) {
        int baseIndex = Arrays.binarySearch(lengthSums, index);
        if (baseIndex < 0) {
            // -1 extra b/c we want to be on the lesser side
            return -baseIndex - 2;
        }
        // Zero-length vectors share their start offset with the vector that follows
        // them, and binarySearch may land on any of the equal entries. Only the
        // last vector starting at this offset can actually contain the index.
        while (baseIndex + 1 < lengthSums.length && lengthSums[baseIndex + 1] == index) {
            baseIndex++;
        }
        return baseIndex;
    }

    /**
     * Changes the length of the concatenated view. Growth is applied to the last
     * backing vector. Shrinking drops trailing backing vectors that are removed
     * entirely (the first backing vector is always kept so the view can grow
     * again later) and shortens the last remaining one. The internal offsets and
     * total length are updated only after the backing vectors accepted the
     * change.
     * 
     * @param length the new length of the concatenated view
     * @throws ArithmeticException   if {@code length} is negative
     * @throws IllegalStateException if there are no backing vectors and
     *                               {@code length} is not zero
     * @throws RuntimeException      if a backing vector that would be dropped
     *                               entirely still holds non-zero values
     */
    @Override
    public void setLength(int length) {
        if (length < 0) {
            throw new ArithmeticException("Can not create an array of negative length");
        }
        if (vecs.length == 0) {
            if (length == 0) {
                return;
            }
            throw new IllegalStateException("Can't change the length of an empty concatenation to " + length + ": there is no backing vector to resize");
        }
        int toAdd = length - length();
        int pos = vecs.length - 1;
        if (toAdd > 0) {
            vecs[pos].setLength(vecs[pos].length() + toAdd);
        } else {
            // decreasing (toAdd <= 0): peel off whole trailing vectors first
            while (pos > 0 && -toAdd >= vecs[pos].length()) {
                if (vecs[pos].nnz() > 0) {
                    throw new RuntimeException("Can't decrease the length of this vector from " + length() + " to " + length + " due to non-zero value");
                }
                toAdd += vecs[pos--].length();
            }

            // the backing vector is presumably expected to raise an error itself if
            // it can not be shortened; nothing of ours has been modified at this point
            vecs[pos].setLength(vecs[pos].length() + toAdd);
            vecs = Arrays.copyOf(vecs, pos + 1);
            // keep the offsets in step with the vectors so index lookups never
            // resolve to a dropped vector
            lengthSums = Arrays.copyOf(lengthSums, pos + 1);
        }
        // start offsets of the remaining vectors are unchanged; only the total moves
        totalLength = length;
    }

}
